# -*- coding: UTF-8 -*-
"""
@author: hhyo、yyukai
@license: Apache Licence
@file: redis.py
@time: 2019/03/26
"""

import json
import re
import shlex

import redis
import logging
import traceback

from common.utils.timer import FuncTimer
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult

__author__ = "hhyo"

logger = logging.getLogger("default")


class RedisEngine(EngineBase):
    def get_connection(self, db_name=None):
        """
        Build a redis-py client for the configured instance.

        :param db_name: database index to select. When it is ``None`` or an
            empty string the engine's own ``self.db_name`` is used instead.
            An explicit ``0`` is honoured (a plain ``or`` fallback would
            treat it as "not given").
        :return: ``redis.cluster.RedisCluster`` when ``self.mode`` is
            ``"cluster"`` (``db_name`` is not passed in that mode),
            otherwise ``redis.Redis`` bound to the chosen database.
        """
        # Fall back only when no database was supplied; 0 is a valid index.
        if db_name is None or db_name == "":
            db_name = self.db_name

        # Options shared by the standalone and the cluster client.
        common_options = dict(
            host=self.host,
            port=self.port,
            username=self.user,
            # An empty password is sent as "no password".
            password=self.password or None,
            encoding_errors="ignore",
            decode_responses=True,
            socket_connect_timeout=10,
            ssl=self.instance.is_ssl,
        )
        if self.mode == "cluster":
            return redis.cluster.RedisCluster(**common_options)
        return redis.Redis(db=db_name, **common_options)

    # Engine name; presumably the label shown for this engine type — confirm against EngineBase.
    name = "Redis"

    # Short human-readable description of the engine.
    info = "Redis engine"

    def test_connection(self):
        """
        Check connectivity by listing the databases.

        :return: whatever ``get_all_databases`` returns.
        """
        databases = self.get_all_databases()
        return databases

    def get_all_databases(self, **kwargs):
        """
        List the database indexes of the instance.

        The count is read from ``CONFIG GET databases``; if that fails it is
        derived from the ``INFO Keyspace`` section instead.

        :return: ResultSet whose ``rows`` is a list of index strings
            ("0", "1", ...).
        """
        result = ResultSet(full_sql="CONFIG GET databases")
        conn = self.get_connection()
        try:
            db_count = conn.config_get("databases")["databases"]
        except Exception as e:
            # CONFIG GET could not be used, so the number of databases is
            # worked out from the INFO output instead.
            # Failure case 1: AWS ElastiCache (Redis) does not support some
            #   commands, e.g. "config get xx" and parts of "acl".
            # Failure case 2: the Redis user lacks admin permission (-@admin);
            #   the server answers "this user has no permissions to run the
            #   'config' command or its subcommand".
            # Steps:
            # - read the keyspace section via info("Keyspace");
            # - extract the database number from each section key (db0, db1, ...);
            # - use the highest number seen, but never fewer than 16 databases (0-15).
            logger.warning(f"Redis CONFIG GET databases 执行报错，异常信息：{e}")
            seen_indexes = []
            for section in conn.info("Keyspace").keys():
                pieces = section.split("db")
                if len(pieces) == 2:
                    seen_indexes.append(int(pieces[1]))
            db_count = max(seen_indexes + [15]) + 1

        result.rows = [str(index) for index in range(int(db_count))]
        return result

    def get_all_tables(self, db_name, **kwargs):
        """
        Return a preview of the keys in a database.

        Redis keys play the role of tables here. Only part of the keyspace is
        scanned: at most 100 keys are collected.

        :param db_name: database index to scan.
        :return: ResultSet whose ``rows`` is the list of collected keys; on
            error ``message`` carries the exception text and ``rows`` holds
            whatever was collected before the failure.
        """
        result = ResultSet(full_sql="")
        max_results = 100
        preview_keys = []
        try:
            conn = self.get_connection(db_name)
            for key in conn.scan_iter(match=None, count=20):
                # Stop once the preview is full.
                if len(preview_keys) >= max_results:
                    break
                preview_keys.append(key)
        except Exception as e:
            logger.error(f"get_all_tables执行报错，异常信息：{e}")
            result.message = f"{e}"
        result.rows = preview_keys
        return result

    def query_check(self, db_name=None, sql="", limit_num=0):
        """提交查询前的检查"""
        result = {"msg": "", "bad_query": True, "filtered_sql": sql, "has_star": False}
        safe_cmd = [
            "scan",
            "exists",
            "ttl",
            "pttl",
            "type",
            "get",
            "mget",
            "strlen",
            "hgetall",
            "hlen",
            "hexists",
            "hget",
            "hmget",
            "hkeys",
            "hvals",
            "smembers",
            "scard",
            "sdiff",
            "sunion",
            "sismember",
            "llen",
            "lrange",
            "lindex",
            "zrange",
            "zrangebyscore",
            "zscore",
            "zcard",
            "zcount",
            "zrank",
            "info",
        ]
        # 命令校验，仅可以执行safe_cmd内的命令
        for cmd in safe_cmd:
            if re.match(rf"^{cmd}", sql.strip(), re.I):
                result["bad_query"] = False
                break
        if result["bad_query"]:
            result["msg"] = "禁止执行该命令！"
        return result

    def processlist(self, command_type, **kwargs):
        """
        Return the client connections of the server, least idle first.

        :param command_type: not used by this engine.
        :return: ResultSet whose ``rows`` is the list of client dicts from
            ``client_list()``, sorted ascending by their ``idle`` field.
        """
        sql = "client list"
        result_set = ResultSet(full_sql=sql)
        conn = self.get_connection(db_name=0)
        clients = conn.client_list()

        def idle_time(client):
            # CLIENT LIST fields arrive as text, so compare them numerically;
            # otherwise "10" would sort before "9". A missing or non-numeric
            # value is treated as 0 instead of breaking the sort.
            try:
                return int(client.get("idle"))
            except (TypeError, ValueError):
                return 0

        # Sort by idle time, ascending.
        result_set.rows = sorted(clients, key=idle_time)
        return result_set

    def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
        """
        Run one Redis command and wrap its reply in a ResultSet.

        The command line is tokenised with ``shlex.split``. The reply is
        shaped by type:

        - list/tuple: one row per element in a single ``Result`` column; for
          commands starting with ``scan`` the first element of the reply
          becomes the first row and the elements of the second follow.
        - dict: ``field`` / ``value`` rows; dict and list values are
          serialised to JSON strings.
        - anything else: a single row holding the value.

        :param db_name: database index to run against.
        :param sql: the Redis command line.
        :param limit_num: when greater than 0, rows are truncated to this many.
        :param close_conn: not used by this engine.
        :return: ResultSet; on failure ``error`` carries the exception text.
        """
        result_set = ResultSet(full_sql=sql)
        try:
            client = self.get_connection(db_name=db_name)
            reply = client.execute_command(*shlex.split(sql))
            result_set.column_list = ["Result"]
            if isinstance(reply, (list, tuple)):
                if re.match(r"^scan", sql.strip(), re.I):
                    scanned = [[key] for key in reply[1]]
                    result_set.rows = tuple([[reply[0]]] + scanned)
                    result_set.affected_rows = len(reply[1])
                else:
                    result_set.rows = tuple([item] for item in reply)
                    result_set.affected_rows = len(reply)
            elif isinstance(reply, dict):
                result_set.column_list = ["field", "value"]
                # Nested dict/list values are rendered as JSON strings.
                result_set.rows = tuple(
                    [
                        field,
                        json.dumps(value) if isinstance(value, (dict, list)) else value,
                    ]
                    for field, value in reply.items()
                )
                result_set.affected_rows = len(result_set.rows)
            else:
                result_set.rows = ([reply],)
                result_set.affected_rows = 1 if reply else 0
            if limit_num > 0:
                result_set.rows = result_set.rows[:limit_num]
        except Exception as e:
            logger.warning(
                f"Redis命令执行报错，语句：{sql}， 错误信息：{traceback.format_exc()}"
            )
            result_set.error = str(e)
        return result_set

    def filter_sql(self, sql="", limit_num=0):
        return sql.strip()

    def query_masking(self, db_name=None, sql="", resultset=None):
        """不做脱敏"""
        return resultset

    def execute_check(self, db_name=None, sql=""):
        """
        Pre-execution review of a change ticket.

        ``sql`` is split on newlines; every non-blank line becomes one
        ReviewResult marked "Audit completed". No validation of the commands
        is performed here.

        :param db_name: unused here.
        :param sql: newline-separated Redis commands.
        :return: ReviewSet with one row per command, numbered from 1.
        """
        check_result = ReviewSet(full_sql=sql)
        statements = [s for s in map(str.strip, sql.split("\n")) if s]
        for seq, statement in enumerate(statements, start=1):
            check_result.rows += [
                ReviewResult(
                    id=seq,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="暂不支持显示影响行数",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
            ]
        return check_result

    def execute_workflow(self, workflow):
        """
        Execute a change ticket and report the outcome per command.

        The ticket's SQL content is split on newlines and the non-blank lines
        are executed in order on one connection. Execution stops at the first
        failure: that command is reported as "Execute Failed" and every later
        command as "Audit completed" (not executed).

        :param workflow: the ticket; ``workflow.sqlworkflowcontent.sql_content``
            and ``workflow.db_name`` are read.
        :return: ReviewSet with one row per command; ``error`` is set to the
            exception text when a failure occurred.
        """
        sql = workflow.sqlworkflowcontent.sql_content
        split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()]
        execute_result = ReviewSet(full_sql=sql)
        line = 1
        cmd = None
        try:
            conn = self.get_connection(db_name=workflow.db_name)
            for cmd in split_sql:
                with FuncTimer() as t:
                    conn.execute_command(*shlex.split(cmd))
                execute_result.rows.append(
                    ReviewResult(
                        id=line,
                        errlevel=0,
                        stagestatus="Execute Successfully",
                        errormessage="暂不支持显示影响行数",
                        sql=cmd,
                        affected_rows=0,
                        execute_time=t.cost,
                    )
                )
                line += 1
        except Exception as e:
            logger.warning(
                f"Redis命令执行报错，语句：{cmd or sql}， 错误信息：{traceback.format_exc()}"
            )
            # If the failure happened before any command was attempted (for
            # example while connecting), attribute it to the first command so
            # the report neither shows an empty statement nor loses that command.
            if cmd is not None:
                failed_sql = cmd
            elif split_sql:
                failed_sql = split_sql[0]
            else:
                failed_sql = sql
            # Record the failing command in the result.
            execute_result.error = str(e)
            execute_result.rows.append(
                ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="Execute Failed",
                    errormessage=f"异常信息：{e}",
                    sql=failed_sql,
                    affected_rows=0,
                    execute_time=0,
                )
            )
            line += 1
            # Commands after the failing one are reported as reviewed but not executed.
            for statement in split_sql[line - 1 :]:
                execute_result.rows.append(
                    ReviewResult(
                        id=line,
                        errlevel=0,
                        stagestatus="Audit completed",
                        errormessage="前序语句失败, 未执行",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
                line += 1
        return execute_result
